/*
 * SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: NVIDIA TensorRT Source Code License Agreement
 *
 * NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 * property and proprietary rights in and to this material, related
 * documentation and any modifications thereto. Any use, reproduction,
 * disclosure or distribution of this material and related documentation
 * without an express license agreement from NVIDIA CORPORATION or
 * its affiliates is strictly prohibited.
 */

#pragma once

#include <cuda_runtime_api.h>

#include "moe_gemm_kernels.h"

namespace tensorrt_llm {
namespace kernels {
namespace cutlass_kernels {

// Keep in sync with the signature generated by generate_kernels.py
//
// Launches a TMA warp-specialized grouped-GEMM kernel for MoE (mixture of
// experts). This is a declaration only; explicit instantiations for each
// template-parameter combination are produced by generate_kernels.py, so the
// signature below must match the generated code token-for-token.
//
// Template parameters:
//   Arch         - target GPU architecture tag (TMA warp specialization
//                  implies Hopper-class hardware; see `hopper_input` below)
//   T            - activation/input element type
//   WeightType   - expert weight element type (may differ from T, e.g. for
//                  weight-only quantization — NOTE(review): inferred from the
//                  name, confirm against kernel definitions)
//   OutputType   - output element type
//   EpilogueTag  - CUTLASS epilogue operation tag
//   FUSION       - which epilogue fusion variant to apply (enum declared in
//                  TmaWarpSpecializedGroupedGemmInput, see moe_gemm_kernels.h)
//   TileShape    - CTA tile shape for the GEMM mainloop
//   ClusterShape - thread-block cluster shape
//   BIAS         - whether a bias term is applied in the epilogue
//
// Runtime parameters:
//   hopper_input          - bundled problem descriptors/pointers for the
//                           grouped GEMM (defined in moe_gemm_kernels.h)
//   num_experts           - number of experts (grouped-GEMM problem count)
//   multi_processor_count - number of SMs on the target device, used for
//                           launch sizing
//   stream                - CUDA stream the kernel is enqueued on
//   kernel_occupancy      - presumably an optional out-param reporting kernel
//                           occupancy when non-null — TODO confirm in the
//                           generated definition
//   workspace_size        - presumably an optional out-param for a
//                           workspace-size query mode — TODO confirm in the
//                           generated definition
template <typename Arch, typename T, typename WeightType, typename OutputType, typename EpilogueTag,
          tensorrt_llm::TmaWarpSpecializedGroupedGemmInput::EpilogueFusion FUSION,
          typename TileShape, typename ClusterShape, bool BIAS>
void tma_warp_specialized_generic_moe_gemm_kernelLauncher(
    TmaWarpSpecializedGroupedGemmInput hopper_input, int num_experts, int multi_processor_count,
    cudaStream_t stream, int* kernel_occupancy, size_t* workspace_size);

}  // namespace cutlass_kernels
}  // namespace kernels
}  // namespace tensorrt_llm
